{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.5"
    },
    "colab": {
      "name": "4-6_OpenPose_training.ipynb",
      "provenance": [],
      "collapsed_sections": []
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eYtEkX621-OC"
      },
      "source": [
        "# 4.6 学習と検証の実施\n",
        "\n",
        "- 本ファイルでは、OpenPoseの学習を実施します（簡易な学習とするため、検証データでの検証は省略します）。AWSのGPUマシンで計算します。\n",
        "- p2.xlargeで45分ほどかかります。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LZtHjCjD1-OH"
      },
      "source": [
        "# 学習目標\n",
        "\n",
        "1.\tOpenPoseの学習を実装できるようになる"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xAm0G0mW1-OI"
      },
      "source": [
        "# 事前準備\n",
        "\n",
        "- これまでの章で実装したクラスと関数をフォルダ「utils」内に用意しています\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "X7WUHRm71-OJ"
      },
      "source": [
        "# パッケージのimport\n",
        "import random\n",
        "import math\n",
        "import time\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.utils.data as data\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UhDCLEgj1-OK"
      },
      "source": [
        "# 初期設定\n",
        "# Setup seeds\n",
        "torch.manual_seed(1234)\n",
        "np.random.seed(1234)\n",
        "random.seed(1234)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "utz_Cyy21-OL"
      },
      "source": [
        "# DataLoader作成"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OvEffoMb1-OL"
      },
      "source": [
        "from utils.dataloader import make_datapath_list, DataTransform, COCOkeypointsDataset\n",
        "\n",
        "# MS COCOのファイルパスリスト作成\n",
        "train_img_list, train_mask_list, val_img_list, val_mask_list, train_meta_list, val_meta_list = make_datapath_list(\n",
        "    rootpath=\"./data/\")\n",
        "\n",
        "# Dataset作成\n",
        "# 本書ではデータ量の問題から、trainをval_listで作成している点に注意\n",
        "train_dataset = COCOkeypointsDataset(\n",
        "    val_img_list, val_mask_list, val_meta_list, phase=\"train\", transform=DataTransform())\n",
        "\n",
        "# 今回は簡易な学習とし検証データは作成しない\n",
        "# val_dataset = CocokeypointsDataset(val_img_list, val_mask_list, val_meta_list, phase=\"val\", transform=DataTransform())\n",
        "\n",
        "# DataLoader作成\n",
        "batch_size = 32\n",
        "\n",
        "train_dataloader = data.DataLoader(\n",
        "    train_dataset, batch_size=batch_size, shuffle=True)\n",
        "\n",
        "# val_dataloader = data.DataLoader(\n",
        "#    val_dataset, batch_size=batch_size, shuffle=False)\n",
        "\n",
        "# 辞書型変数にまとめる\n",
        "# dataloaders_dict = {\"train\": train_dataloader, \"val\": val_dataloader}\n",
        "dataloaders_dict = {\"train\": train_dataloader, \"val\": None}\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g53aOeuY1-ON"
      },
      "source": [
        "# ネットワークモデル作成"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oh1g-qsZ1-OO"
      },
      "source": [
        "# Instantiate the network; OpenPoseNet is imported from utils.openpose_net\n",
        "# and is constructed here with its default arguments.\n",
        "from utils.openpose_net import OpenPoseNet\n",
        "net = OpenPoseNet()\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iehMlbZl1-OQ"
      },
      "source": [
        "# 損失関数を定義"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5cy4COD81-OQ"
      },
      "source": [
        "# 損失関数の設定\n",
        "class OpenPoseLoss(nn.Module):\n",
        "    \"\"\"OpenPoseの損失関数のクラスです。\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super(OpenPoseLoss, self).__init__()\n",
        "\n",
        "    def forward(self, saved_for_loss, heatmap_target, heat_mask, paf_target, paf_mask):\n",
        "        \"\"\"\n",
        "        損失関数の計算。\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        saved_for_loss : OpenPoseNetの出力(リスト)\n",
        "\n",
        "        heatmap_target : [num_batch, 19, 46, 46]\n",
        "            正解の部位のアノテーション情報\n",
        "\n",
        "        heatmap_mask : [num_batch, 19, 46, 46]\n",
        "            heatmap画像のmask\n",
        "\n",
        "        paf_target : [num_batch, 38, 46, 46]\n",
        "            正解のPAFのアノテーション情報\n",
        "\n",
        "        paf_mask : [num_batch, 38, 46, 46]\n",
        "            PAF画像のmask\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        loss : テンソル\n",
        "            損失の値\n",
        "        \"\"\"\n",
        "\n",
        "        total_loss = 0\n",
        "        # ステージごとに計算します\n",
        "        for j in range(6):\n",
        "\n",
        "            # PAFsとheatmapsにおいて、マスクされている部分（paf_mask=0など）は無視させる\n",
        "            # PAFs\n",
        "            pred1 = saved_for_loss[2 * j] * paf_mask\n",
        "            gt1 = paf_target.float() * paf_mask\n",
        "\n",
        "            # heatmaps\n",
        "            pred2 = saved_for_loss[2 * j + 1] * heat_mask\n",
        "            gt2 = heatmap_target.float()*heat_mask\n",
        "\n",
        "            total_loss += F.mse_loss(pred1, gt1, reduction='mean') + \\\n",
        "                F.mse_loss(pred2, gt2, reduction='mean')\n",
        "\n",
        "        return total_loss\n",
        "\n",
        "\n",
        "criterion = OpenPoseLoss()\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HWDPFIh-1-OR"
      },
      "source": [
        "# 最適化手法を設定"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7dWlCWqG1-OR"
      },
      "source": [
        "optimizer = optim.SGD(net.parameters(), lr=1e-2,\n",
        "                      momentum=0.9,\n",
        "                      weight_decay=0.0001)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aNgpvEju1-OS"
      },
      "source": [
        "# 学習を実施"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6acs014S1-OS"
      },
      "source": [
        "# モデルを学習させる関数を作成\n",
        "\n",
        "\n",
        "def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):\n",
        "\n",
        "    # GPUが使えるかを確認\n",
        "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "    print(\"使用デバイス：\", device)\n",
        "\n",
        "    # ネットワークをGPUへ\n",
        "    net.to(device)\n",
        "\n",
        "    # ネットワークがある程度固定であれば、高速化させる\n",
        "    torch.backends.cudnn.benchmark = True\n",
        "\n",
        "    # 画像の枚数\n",
        "    num_train_imgs = len(dataloaders_dict[\"train\"].dataset)\n",
        "    batch_size = dataloaders_dict[\"train\"].batch_size\n",
        "\n",
        "    # イテレーションカウンタをセット\n",
        "    iteration = 1\n",
        "\n",
        "    # epochのループ\n",
        "    for epoch in range(num_epochs):\n",
        "\n",
        "        # 開始時刻を保存\n",
        "        t_epoch_start = time.time()\n",
        "        t_iter_start = time.time()\n",
        "        epoch_train_loss = 0.0  # epochの損失和\n",
        "        epoch_val_loss = 0.0  # epochの損失和\n",
        "\n",
        "        print('-------------')\n",
        "        print('Epoch {}/{}'.format(epoch+1, num_epochs))\n",
        "        print('-------------')\n",
        "\n",
        "        # epochごとの訓練と検証のループ\n",
        "        for phase in ['train', 'val']:\n",
        "            if phase == 'train':\n",
        "                net.train()  # モデルを訓練モードに\n",
        "                optimizer.zero_grad()\n",
        "                print('（train）')\n",
        "\n",
        "            # 今回は検証はスキップ\n",
        "            else:\n",
        "                continue\n",
        "                # net.eval()   # モデルを検証モードに\n",
        "                # print('-------------')\n",
        "                # print('（val）')\n",
        "\n",
        "            # データローダーからminibatchずつ取り出すループ\n",
        "            for imges, heatmap_target, heat_mask, paf_target, paf_mask in dataloaders_dict[phase]:\n",
        "                # ミニバッチがサイズが1だと、バッチノーマライゼーションでエラーになるのでさける\n",
        "                # issue #186より不要なのでコメントアウト\n",
        "                # if imges.size()[0] == 1:\n",
        "                #     continue\n",
        "\n",
        "                # GPUが使えるならGPUにデータを送る\n",
        "                imges = imges.to(device)\n",
        "                heatmap_target = heatmap_target.to(device)\n",
        "                heat_mask = heat_mask.to(device)\n",
        "                paf_target = paf_target.to(device)\n",
        "                paf_mask = paf_mask.to(device)\n",
        "\n",
        "                # optimizerを初期化\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "                # 順伝搬（forward）計算\n",
        "                with torch.set_grad_enabled(phase == 'train'):\n",
        "                    # (out6_1, out6_2)は使わないので _ で代替\n",
        "                    _, saved_for_loss = net(imges)\n",
        "\n",
        "                    loss = criterion(saved_for_loss, heatmap_target,\n",
        "                                     heat_mask, paf_target, paf_mask)\n",
        "                    del saved_for_loss\n",
        "                    # 訓練時はバックプロパゲーション\n",
        "                    if phase == 'train':\n",
        "                        loss.backward()\n",
        "                        optimizer.step()\n",
        "\n",
        "                        if (iteration % 10 == 0):  # 10iterに1度、lossを表示\n",
        "                            t_iter_finish = time.time()\n",
        "                            duration = t_iter_finish - t_iter_start\n",
        "                            print('イテレーション {} || Loss: {:.4f} || 10iter: {:.4f} sec.'.format(\n",
        "                                iteration, loss.item()/batch_size, duration))\n",
        "                            t_iter_start = time.time()\n",
        "\n",
        "                        epoch_train_loss += loss.item()\n",
        "                        iteration += 1\n",
        "\n",
        "                    # 検証時\n",
        "                    # else:\n",
        "                        #epoch_val_loss += loss.item()\n",
        "\n",
        "        # epochのphaseごとのlossと正解率\n",
        "        t_epoch_finish = time.time()\n",
        "        print('-------------')\n",
        "        print('epoch {} || Epoch_TRAIN_Loss:{:.4f} ||Epoch_VAL_Loss:{:.4f}'.format(\n",
        "            epoch+1, epoch_train_loss/num_train_imgs, 0))\n",
        "        print('timer:  {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))\n",
        "        t_epoch_start = time.time()\n",
        "\n",
        "    # 最後のネットワークを保存する\n",
        "    torch.save(net.state_dict(), 'weights/openpose_net_' +\n",
        "               str(epoch+1) + '.pth')\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KDz3LPQr1-OT",
        "outputId": "92e793d1-5755-4e0f-cd91-8a2ea68e4ebe"
      },
      "source": [
        "# 学習・検証を実行する\n",
        "num_epochs = 2\n",
        "train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "使用デバイス： cuda:0\n",
            "-------------\n",
            "Epoch 1/2\n",
            "-------------\n",
            "（train）\n",
            "イテレーション 10 || Loss: 0.0094 || 10iter: 113.7127 sec.\n",
            "イテレーション 20 || Loss: 0.0082 || 10iter: 90.4145 sec.\n",
            "イテレーション 30 || Loss: 0.0069 || 10iter: 88.4890 sec.\n",
            "イテレーション 40 || Loss: 0.0058 || 10iter: 90.9961 sec.\n",
            "イテレーション 50 || Loss: 0.0050 || 10iter: 90.8274 sec.\n",
            "イテレーション 60 || Loss: 0.0042 || 10iter: 89.7553 sec.\n",
            "イテレーション 70 || Loss: 0.0038 || 10iter: 91.1155 sec.\n",
            "イテレーション 80 || Loss: 0.0031 || 10iter: 91.3307 sec.\n",
            "イテレーション 90 || Loss: 0.0027 || 10iter: 91.7214 sec.\n",
            "イテレーション 100 || Loss: 0.0026 || 10iter: 92.2645 sec.\n",
            "イテレーション 110 || Loss: 0.0023 || 10iter: 91.7421 sec.\n",
            "イテレーション 120 || Loss: 0.0020 || 10iter: 90.7930 sec.\n",
            "イテレーション 130 || Loss: 0.0020 || 10iter: 91.3045 sec.\n",
            "イテレーション 140 || Loss: 0.0019 || 10iter: 91.6105 sec.\n",
            "イテレーション 150 || Loss: 0.0016 || 10iter: 90.2619 sec.\n",
            "-------------\n",
            "epoch 1 || Epoch_TRAIN_Loss:0.0043 ||Epoch_VAL_Loss:0.0000\n",
            "timer:  1462.0789 sec.\n",
            "-------------\n",
            "Epoch 2/2\n",
            "-------------\n",
            "（train）\n",
            "イテレーション 160 || Loss: 0.0017 || 10iter: 64.3399 sec.\n",
            "イテレーション 170 || Loss: 0.0017 || 10iter: 91.2324 sec.\n",
            "イテレーション 180 || Loss: 0.0015 || 10iter: 92.3138 sec.\n",
            "イテレーション 190 || Loss: 0.0015 || 10iter: 90.3904 sec.\n",
            "イテレーション 200 || Loss: 0.0015 || 10iter: 90.9617 sec.\n",
            "イテレーション 210 || Loss: 0.0016 || 10iter: 91.2119 sec.\n",
            "イテレーション 220 || Loss: 0.0014 || 10iter: 90.6868 sec.\n",
            "イテレーション 230 || Loss: 0.0016 || 10iter: 90.8710 sec.\n",
            "イテレーション 240 || Loss: 0.0017 || 10iter: 90.3973 sec.\n",
            "イテレーション 250 || Loss: 0.0014 || 10iter: 90.8158 sec.\n",
            "イテレーション 260 || Loss: 0.0012 || 10iter: 92.8508 sec.\n",
            "イテレーション 270 || Loss: 0.0012 || 10iter: 91.9698 sec.\n",
            "イテレーション 280 || Loss: 0.0015 || 10iter: 90.8905 sec.\n",
            "イテレーション 290 || Loss: 0.0011 || 10iter: 91.2742 sec.\n",
            "イテレーション 300 || Loss: 0.0012 || 10iter: 91.0789 sec.\n",
            "-------------\n",
            "epoch 2 || Epoch_TRAIN_Loss:0.0015 ||Epoch_VAL_Loss:0.0000\n",
            "timer:  1437.0403 sec.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bsEvmpMH1-OU"
      },
      "source": [
        "以上"
      ]
    }
  ]
}